import ast
import os
import pdb
import time

import keras.backend as K
import numpy as np
import pandas as pd
from keras import metrics
from keras.layers import Input
from keras.layers import LSTM, Dense, Masking, Concatenate, concatenate, BatchNormalization, Bidirectional
from keras.models import Sequential, Model
from keras.preprocessing import sequence
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

# Load the training data.

# Expose only GPU 1 (numbered in PCI bus order) to the backend.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

start = time.time()  # wall-clock start of the run (no longer clobbered below)
data_path = "/home/furkan/deepXCNV/FREEC/cnvkit_train_samp/"

# Parallel per-segment lists; index i in each list describes the same line.
data_list_readdepths = []  # read-depth sequence of the segment (variable length)
data_list_indexes = []  # (start, end) coordinates of the segment
data_list_chrs = []  # chromosome number with the 'chr' prefix stripped
data_list_canavar_preds = []  # gt column (float) -> regression target
data_list_cnvkit_preds = []  # cn column (int)   -> cnvkit call, model input

n_skipped_lines = 0  # lines dropped because they could not be parsed

# Sort the listing: os.listdir() order is arbitrary, and the sample order feeds
# the seeded train_test_split later, so an unsorted listing makes the split
# irreproducible across machines/filesystems.
files_list = sorted(os.listdir(data_path))
for filename in files_list:
    file_path = os.path.join(data_path, filename)
    if not os.path.isfile(file_path):
        continue  # open() would raise on a sub-directory
    with open(file_path, 'r') as f:
        try:
            next(f)  # skip the header line
        except StopIteration:
            # Empty file: there is no header and nothing to parse.
            print("WTF", filename)
            continue

        for line in f:
            # Tab-separated columns: chrom, start, end, cn, gt, read_depth.
            # read_depth is parsed as a Python literal (presumably a list of
            # numbers -- TODO confirm against the file producer).
            try:
                chrom, seg_start, seg_end, cn, gt, read_depth = line.split('\t')
                seg_start, seg_end = int(seg_start), int(seg_end)
                cn, gt = int(cn), float(gt)
                # literal_eval accepts only literals; eval() would execute any
                # expression embedded in a data file.
                read_depth = ast.literal_eval(read_depth.strip())
                # int() raises ValueError for non-numeric names such as chrX or
                # chrY, so those segments are skipped.
                chrom = int(chrom.replace('chr', ''))
            except (ValueError, TypeError, SyntaxError):
                # Wrong column count, non-numeric field, or malformed literal:
                # skip the line as before, but count it instead of hiding it.
                n_skipped_lines += 1
                continue

            data_list_readdepths.append(read_depth)
            data_list_indexes.append((seg_start, seg_end))
            data_list_chrs.append(chrom)
            data_list_canavar_preds.append(gt)
            data_list_cnvkit_preds.append(cn)

if n_skipped_lines:
    print("Skipped unparseable lines: ", n_skipped_lines)

# Convert the parsed lists into model-ready numpy arrays.
if len(data_list_readdepths) == 0:
    raise RuntimeError("No parseable read-depth segments were loaded.")

data_list_indexes = np.asarray(data_list_indexes)
data_list_canavar_preds = np.asarray(data_list_canavar_preds)
data_list_cnvkit_preds = np.asarray(data_list_cnvkit_preds)

readdepth_maxlen = 192000  # every read-depth signal is padded/truncated to this
readdepth_bin = 100  # consecutive samples averaged into one timestep

# Hand the ragged list straight to pad_sequences: np.asarray() on ragged nested
# lists raises ValueError on NumPy >= 1.24. dtype='float32' keeps fractional
# read depths (the default int32 would truncate them). The -1 padding value is
# what the Masking layer of the model treats as "no data". With the default
# 'pre' truncating, over-long signals lose their leading samples.
data_list_readdepths = sequence.pad_sequences(
    data_list_readdepths, maxlen=readdepth_maxlen, dtype='float32', value=-1)

# Downsample (n, maxlen) -> (n, maxlen // bin) by averaging each bin.
# NOTE(review): with 'pre' padding, a signal whose length is not a multiple of
# readdepth_bin yields one bin mixing -1 padding with real samples; its mean is
# not -1, so Masking will not hide it.
data_list_readdepths = data_list_readdepths.reshape(
    len(data_list_readdepths), readdepth_maxlen // readdepth_bin, readdepth_bin
).mean(axis=2)

# Add a trailing feature axis for the LSTM: (n, timesteps) -> (n, timesteps, 1).
data_list_readdepths = np.expand_dims(data_list_readdepths, axis=2)

print("Read depths data matrix shape: ", data_list_readdepths.shape)
print("cnvkit predictions data matrix shape: ", data_list_cnvkit_preds.shape)
print("Canavar predictions (labels) data matrix shape: ", data_list_canavar_preds.shape)

# Model inputs / target:
#   input1 <- data_list_cnvkit_preds   (one cnvkit call per segment)
#   input2 <- data_list_readdepths     (binned read-depth sequence)
#   labels <- data_list_canavar_preds  (regression target)

# model
max_length = 192000  # maximum length of read depth signals
# Integer division: Input() needs an int timestep count; in Python 3
# `max_length / 100` is the float 1920.0.
inpsize = max_length // 100

input1 = Input(shape=(1,))  # cnvkit prediction
input2 = Input(shape=(inpsize, 1))  # read depth sequence
# -1 matches the value the read-depth signals are padded with, so padded
# timesteps are masked.
masked_input2 = Masking(mask_value=-1)(input2)
features1 = BatchNormalization()(masked_input2)
features2 = Bidirectional(LSTM(128))(features1)
features3 = BatchNormalization()(features2)
merged = concatenate([features3, input1])
features4 = Dense(100, activation='relu')(merged)
# ReLU on the output restricts predictions to >= 0.
# NOTE(review): only appropriate if the gt targets are non-negative -- confirm.
output = Dense(1, activation='relu')(features4)

model = Model(inputs=[input1, input2], outputs=output)
# summary() prints the table itself and returns None; print() around it only
# added a stray "None" line.
model.summary()

# Train/test split: 90% train, 10% held-out test, fixed seed.
data_list_cnvkit_preds_train, data_list_cnvkit_preds_test, \
data_list_readdepths_train, data_list_readdepths_test, \
data_list_canavar_preds_train, data_list_canavar_preds_test = train_test_split(
    data_list_cnvkit_preds, data_list_readdepths, data_list_canavar_preds,
    test_size=0.1, random_state=35)

output_dir = '../outputs'
# np.save and model.save do not create missing directories.
os.makedirs(output_dir, exist_ok=True)

# Persist the held-out test set so evaluation can be run separately.
np.save(os.path.join(output_dir, 'data_list_cnvkit_preds_test.npy'), data_list_cnvkit_preds_test)
np.save(os.path.join(output_dir, 'data_list_readdepthscnvkit_test.npy'), data_list_readdepths_test)
np.save(os.path.join(output_dir, 'data_list_canavarcnvkit_preds_test.npy'), data_list_canavar_preds_test)

batch_size = 512
epochs = 60

model.compile(loss='mean_absolute_error', optimizer='adam')
# validation_split holds out the last 20% of the (already shuffled) training set.
model.fit([data_list_cnvkit_preds_train, data_list_readdepths_train],
          data_list_canavar_preds_train,
          validation_split=0.2, epochs=epochs, batch_size=batch_size)

# NOTE(review): the file name says "bs256" but training uses batch_size=512.
# The name is kept unchanged so downstream consumers still find the file --
# confirm which is intended and rename.
model.save(os.path.join(
    output_dir,
    'deepXCNVcnvkit_batchnorm_bilstm128_batchnorm_dense100_dense1_bs256_padding-1_60epochs_traintestsplitted_mae.h5'))

